{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bbd351bc-bd19-4a6f-ad9c-15c56a1a29da",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as transforms\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1011aff1-f0b6-42da-b30b-bb819cb238ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResidualBlock(nn.Module):\n",
    "    def __init__(self, inchannel, outchannel, stride=1):\n",
    "        super(ResidualBlock, self).__init__()\n",
    "        self.left = nn.Sequential(\n",
    "            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(outchannel),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(outchannel)\n",
    "        )\n",
    "        self.shortcut = nn.Sequential()\n",
    "        if stride != 1 or inchannel != outchannel:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(outchannel)\n",
    "            )\n",
    "            \n",
    "    def forward(self, x):\n",
    "        out = self.left(x)\n",
    "        out = out + self.shortcut(x)\n",
    "        out = F.relu(out)\n",
    "        \n",
    "        return out\n",
    "\n",
    "class ResNet(nn.Module):\n",
    "    def __init__(self, ResidualBlock, num_classes=5):\n",
    "        super(ResNet, self).__init__()\n",
    "        self.inchannel = 64\n",
    "        self.conv1 = nn.Sequential(\n",
    "            nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1)\n",
    "        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)\n",
    "        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)      \n",
    "        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)        \n",
    "        self.fc = nn.Linear(14336, num_classes)\n",
    "        \n",
    "        \n",
    "    def make_layer(self, block, channels, num_blocks, stride):\n",
    "        strides = [stride] + [1] * (num_blocks - 1)\n",
    "        layers = []\n",
    "        for stride in strides:\n",
    "            layers.append(block(self.inchannel, channels, stride))\n",
    "            self.inchannel = channels\n",
    "        return nn.Sequential(*layers)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        out = self.layer1(out)\n",
    "        out = self.layer2(out)\n",
    "        out = self.layer3(out)\n",
    "        out = self.layer4(out)\n",
    "        out = F.avg_pool2d(out, 4)\n",
    "        out = out.view(out.size(0), -1)  # 展平成一维\n",
    "        #print(out.shape)  # 调试信息，打印输出尺寸\n",
    "        out = self.fc(out)\n",
    "        return out\n",
    "    \n",
    "def ResNet18():\n",
    "    return ResNet(ResidualBlock)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d9c26936-6832-43c3-9f67-5f82c51d65d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImageClassifier:\n",
    "    def __init__(self, model_path, class_names, num_classes=5):\n",
    "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        self.model = self._load_model(model_path, num_classes)\n",
    "        self.class_names = class_names\n",
    "        self.transform = transforms.Compose([\n",
    "            transforms.Resize((135, 240)),\n",
    "            transforms.ToTensor(),\n",
    "        ])\n",
    "\n",
    "    def _load_model(self, model_path, num_classes):\n",
    "        model = ResNet(ResidualBlock, num_classes).to(self.device)\n",
    "        checkpoint = torch.load(model_path)\n",
    "        model.load_state_dict(checkpoint)\n",
    "        model.eval()\n",
    "        return model\n",
    "\n",
    "    def classify_image(self, image_path):\n",
    "        image = Image.open(image_path)\n",
    "        image = image.crop([300, 80, 1620, 950])\n",
    "        image = self.transform(image)\n",
    "        image = image.unsqueeze(0).to(self.device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            outputs = self.model(image)\n",
    "            probabilities = F.softmax(outputs, dim=1).squeeze().cpu().numpy()\n",
    "\n",
    "        return {self.class_names[i]: probabilities[i] for i in range(len(self.class_names))}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2f2a31ce-2248-421e-93ed-a0b288f1f36f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tri: 0.00013\n",
      "T: 0.99972\n",
      "sin: 0.00001\n",
      "for: 0.00014\n",
      "dou: 0.00000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
      "  return F.conv2d(input, weight, bias, self.stride,\n"
     ]
    }
   ],
   "source": [
    "# 示例使用\n",
    "if __name__ == \"__main__\":\n",
    "    model_path = \"1resnet18_weights.pth\"\n",
    "    image_path = \"test/image000005.jpg\"\n",
    "    class_names = ['tri','T','sin','for','dou']\n",
    "    classifier = ImageClassifier(model_path, class_names)\n",
    "    probabilities = classifier.classify_image(image_path)\n",
    "    \n",
    "    for class_name, prob in probabilities.items():\n",
    "        print(f\"{class_name}: {prob:.5f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
